{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# CIFAR-10 convolutional network\n",
    "\n",
    "This is a small CIFAR-10 convolutional neural network designed to run on one\n",
    "Loihi chip. Because of this size constraint, it is not particularly\n",
    "powerful, and does not achieve anywhere near state-of-the-art results on the\n",
    "task. Nevertheless, the network performs well enough to demonstrate that\n",
    "Loihi is capable of hosting object recognition networks that are larger and\n",
    "more powerful than those needed for MNIST.\n",
    "\n",
    "This notebook requires:\n",
    "- nengo 2.8.0 (``pip install 'nengo==2.8.0'``)\n",
    "- nengo_dl 1.2.1 (``pip install 'nengo-dl==1.2.1'``)\n",
    "- nengo_extras (``pip install nengo-extras``)\n",
    "- nengo_loihi dev (``pip install git+https://github.com/nengo/nengo-loihi.git``)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import collections\n",
    "from functools import partial\n",
    "import os\n",
    "import tempfile\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline\n",
    "import nengo\n",
    "from nengo.utils.compat import is_iterable\n",
    "import nengo_dl\n",
    "from nengo_dl import SoftLIFRate\n",
    "from nengo_extras.data import load_cifar10, one_hot_from_labels\n",
    "import nengo_loihi\n",
    "from nengo_loihi.conv import (\n",
    "    Conv2D, ImageSlice, ImageShape, split_transform)\n",
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "\n",
    "# Fail fast on unsupported library versions. The offending module name is\n",
    "# rebound to None first, so later cells cannot silently run against it.\n",
    "if not nengo.__version__ == '2.8.0':\n",
    "    nengo_version, nengo = nengo.__version__, None\n",
    "    raise ImportError(\n",
    "        \"Nengo version is not correct (must be 2.8.0, got %s)\" % nengo_version)\n",
    "\n",
    "# only the first three version components of nengo_dl are compared\n",
    "if nengo_dl.__version__.split('.')[:3] != ['1', '2', '1']:\n",
    "    nengo_dl_version, nengo_dl = nengo_dl.__version__, None\n",
    "    raise ImportError(\n",
    "        \"nengo_dl version is not correct (must be 1.2.1, got %s)\"\n",
    "        % nengo_dl_version)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class TfConv2d(object):\n",
    "    \"\"\"2D convolution layer, used as the function of a ``nengo_dl.TensorNode``.\n",
    "\n",
    "    The layer is evaluated with ``tf.nn.conv2d`` ('VALID' padding). After\n",
    "    training, ``save_params`` copies the kernel out of the session and\n",
    "    ``transform`` exposes it as a ``Conv2D`` transform.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    name : str\n",
    "        Layer name; also used to name the TensorFlow kernel variable.\n",
    "    n_filters : int\n",
    "        Number of filters (output channels).\n",
    "    kernel_size : int or (int, int) (Default: (3, 3))\n",
    "        Kernel height and width.\n",
    "    strides : int or (int, int) (Default: (1, 1))\n",
    "        Vertical and horizontal strides.\n",
    "    initializer : optional (Default: None)\n",
    "        Passed as ``initializer`` to ``tf.get_variable`` for the kernel.\n",
    "\n",
    "    Notes\n",
    "    -----\n",
    "    ``shape_in`` (an ``ImageShape``) must be assigned by the caller before\n",
    "    the layer is built, called, or converted with ``transform``.\n",
    "    \"\"\"\n",
    "\n",
    "    KERNEL_IDX = 0\n",
    "\n",
    "    def __init__(self, name, n_filters,\n",
    "                 kernel_size=(3, 3), strides=(1, 1), initializer=None):\n",
    "        self.name = name\n",
    "        self.n_filters = n_filters\n",
    "        self.kernel_size = kernel_size if is_iterable(kernel_size) else (\n",
    "            kernel_size, kernel_size)\n",
    "        self.strides = strides if is_iterable(strides) else (strides, strides)\n",
    "        self.padding = 'VALID'\n",
    "        self.initializer = initializer\n",
    "\n",
    "        self.kernel = None\n",
    "        # filled in by ``save_params``; defined here so that calling\n",
    "        # ``transform`` too early fails with a clear message\n",
    "        self.kernel_value = None\n",
    "        self.shape_in = None\n",
    "\n",
    "    def __str__(self):\n",
    "        return '%s(%s)' % (type(self).__name__, self.name)\n",
    "\n",
    "    def output_shape(self, input_shape=None):\n",
    "        \"\"\"Shape of the layer output for ``input_shape``.\n",
    "\n",
    "        If ``input_shape`` is None, ``self.shape_in`` is used (and must\n",
    "        already be set).\n",
    "        \"\"\"\n",
    "        if input_shape is None:\n",
    "            assert self.shape_in is not None\n",
    "            input_shape = self.shape_in\n",
    "        conv2d = Conv2D(\n",
    "            self.n_filters,\n",
    "            input_shape,\n",
    "            kernel_size=self.kernel_size,\n",
    "            strides=self.strides,\n",
    "        )\n",
    "        return conv2d.output_shape\n",
    "\n",
    "    def pre_build(self, shape_in, shape_out):\n",
    "        \"\"\"Create the trainable kernel variable, laid out (si, sj, nc, nf).\"\"\"\n",
    "        assert isinstance(self.shape_in, ImageShape)\n",
    "        assert shape_in[1] == self.shape_in.size\n",
    "        ni, nj, nc = self.shape_in.shape(channels_last=True)\n",
    "        nf = self.n_filters\n",
    "        si, sj = self.kernel_size\n",
    "        self.kernel = tf.get_variable('kernel_%s' % self.name,\n",
    "                                      shape=(si, sj, nc, nf),\n",
    "                                      initializer=self.initializer)\n",
    "\n",
    "    def __call__(self, t, x):\n",
    "        \"\"\"Convolve the flattened batch ``x``; ``t`` is not used.\n",
    "\n",
    "        Returns the result flattened back to ``(batch_size, -1)``.\n",
    "        \"\"\"\n",
    "        batch_size = x.get_shape()[0].value\n",
    "        x = tf.reshape(x, (batch_size,) + self.shape_in.shape())\n",
    "\n",
    "        # make strides 1 for examples, channels\n",
    "        channels_last = self.shape_in.channels_last\n",
    "        strides = ((1,) + self.strides + (1,) if channels_last else\n",
    "                   (1, 1) + self.strides)\n",
    "\n",
    "        data_format = 'NHWC' if channels_last else 'NCHW'\n",
    "        y = tf.nn.conv2d(x, self.kernel, strides, self.padding,\n",
    "                         data_format=data_format)\n",
    "        return tf.reshape(y, (batch_size, -1))\n",
    "\n",
    "    def save_params(self, sess):\n",
    "        \"\"\"Evaluate the kernel in ``sess`` and keep it as ``kernel_value``.\"\"\"\n",
    "        self.kernel_value = sess.run(self.kernel)\n",
    "\n",
    "    def transform(self):\n",
    "        \"\"\"Return the saved kernel as a ``Conv2D`` transform.\n",
    "\n",
    "        ``save_params`` must have been called first.\n",
    "        \"\"\"\n",
    "        assert self.shape_in is not None\n",
    "        assert self.kernel_value is not None, (\n",
    "            \"kernel has not been saved; call save_params first\")\n",
    "        # reorder (si, sj, nc, nf) -> (nc, si, sj, nf)\n",
    "        kernel = np.transpose(self.kernel_value, (2, 0, 1, 3))\n",
    "        return Conv2D.from_kernel(kernel, self.shape_in,\n",
    "                                  strides=self.strides,\n",
    "                                  mode=self.padding.lower(),\n",
    "                                  correlate=True)\n",
    "\n",
    "\n",
    "class TfDense(object):\n",
    "    \"\"\"Fully connected layer, used as the function of a ``nengo_dl.TensorNode``.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    name : str\n",
    "        Layer name; also used to name the TensorFlow weight variable.\n",
    "    n_outputs : int\n",
    "        Number of output units.\n",
    "\n",
    "    Notes\n",
    "    -----\n",
    "    ``shape_in`` (an ``ImageShape``) must be assigned by the caller before\n",
    "    the layer is built.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, name, n_outputs):\n",
    "        self.name = name\n",
    "        self.n_outputs = n_outputs\n",
    "\n",
    "        self.weights = None\n",
    "        # filled in by ``save_params``; defined here so that calling\n",
    "        # ``transform`` too early fails with a clear message\n",
    "        self.weights_value = None\n",
    "        self.shape_in = None\n",
    "\n",
    "    def __str__(self):\n",
    "        # same format as TfConv2d, so log lines name every layer readably\n",
    "        return '%s(%s)' % (type(self).__name__, self.name)\n",
    "\n",
    "    def output_shape(self, input_shape=None):\n",
    "        \"\"\"Shape of the layer output; ``input_shape`` does not affect it.\"\"\"\n",
    "        return ImageShape(1, 1, self.n_outputs, channels_last=True)\n",
    "\n",
    "    def pre_build(self, shape_in, shape_out):\n",
    "        \"\"\"Create the trainable ``(n_inputs, n_outputs)`` weight variable.\"\"\"\n",
    "        assert isinstance(self.shape_in, ImageShape)\n",
    "        assert shape_in[1] == self.shape_in.size\n",
    "        assert shape_out[1] == self.n_outputs\n",
    "        n_inputs = shape_in[1]\n",
    "        n_outputs = self.n_outputs\n",
    "        self.weights = tf.get_variable('weights_%s' % self.name,\n",
    "                                       shape=(n_inputs, n_outputs))\n",
    "\n",
    "    def __call__(self, t, x):\n",
    "        \"\"\"Multiply the flattened batch ``x`` by the weights; ``t`` is unused.\"\"\"\n",
    "        x = tf.reshape(x, (x.shape[0], -1))\n",
    "        return tf.tensordot(x, self.weights, axes=[[1], [0]])\n",
    "\n",
    "    def save_params(self, sess):\n",
    "        \"\"\"Evaluate the weights in ``sess`` and keep them as ``weights_value``.\"\"\"\n",
    "        self.weights_value = sess.run(self.weights)\n",
    "\n",
    "    def transform(self):\n",
    "        \"\"\"Return the saved weights transposed to ``(n_outputs, n_inputs)``.\n",
    "\n",
    "        ``save_params`` must have been called first.\n",
    "        \"\"\"\n",
    "        assert self.weights_value is not None, (\n",
    "            \"weights have not been saved; call save_params first\")\n",
    "        return self.weights_value.T\n",
    "\n",
    "\n",
    "def crossentropy(outputs, targets):\n",
    "    \"\"\"Training loss: softmax cross-entropy summed over all entries.\n",
    "\n",
    "    ``outputs`` are treated as logits and ``targets`` as the label\n",
    "    distributions.\n",
    "    \"\"\"\n",
    "    per_example = tf.nn.softmax_cross_entropy_with_logits_v2(\n",
    "        labels=targets, logits=outputs)\n",
    "    return tf.reduce_sum(per_example)\n",
    "\n",
    "\n",
    "def classification_error(outputs, targets):\n",
    "    \"\"\"Testing metric: percentage of examples whose argmax class is wrong.\n",
    "\n",
    "    Only the last index along axis 1 of ``outputs`` and ``targets`` is used.\n",
    "    \"\"\"\n",
    "    predicted = tf.argmax(outputs[:, -1], axis=-1)\n",
    "    actual = tf.argmax(targets[:, -1], axis=-1)\n",
    "    mismatches = tf.cast(tf.not_equal(predicted, actual), tf.float32)\n",
    "    return 100 * tf.reduce_mean(mismatches)\n",
    "\n",
    "\n",
    "def percentile_rate_l2_loss(x, y, weight=1.0, target=0.0, percentile=99.):\n",
    "    \"\"\"Weighted L2 loss between per-neuron percentile rates and ``target``.\n",
    "\n",
    "    ``x`` has axes (batch examples, time (==1), neurons); the percentile is\n",
    "    taken over the first two axes. ``y`` is not used.\n",
    "    \"\"\"\n",
    "    assert len(x.shape) == 3\n",
    "    neuron_rates = tf.contrib.distributions.percentile(\n",
    "        x, percentile, axis=(0, 1))\n",
    "    deviation = neuron_rates - target\n",
    "    return weight * tf.nn.l2_loss(deviation)\n",
    "\n",
    "\n",
    "def percentile_l2_loss_range(\n",
    "        x, y, weight=1.0, min=0.0, max=np.inf, percentile=99.):\n",
    "    \"\"\"Weighted L2 penalty on percentile rates that fall outside [min, max].\n",
    "\n",
    "    ``x`` has axes (batch examples, time (==1), neurons); the percentile is\n",
    "    taken over the first two axes. ``y`` is not used.\n",
    "    \"\"\"\n",
    "    assert len(x.shape) == 3\n",
    "    neuron_p = tf.contrib.distributions.percentile(x, percentile, axis=(0, 1))\n",
    "    # distance below the lower bound and above the upper bound (0 if inside)\n",
    "    below = tf.maximum(0.0, min - neuron_p)\n",
    "    above = tf.maximum(0.0, neuron_p - max)\n",
    "    out_of_range = below + above\n",
    "    return weight * tf.nn.l2_loss(out_of_range)\n",
    "\n",
    "\n",
    "def has_checkpoint(checkpoint_base):\n",
    "    \"\"\"Return True if any file name starts with the checkpoint base name.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    checkpoint_base : str\n",
    "        Path prefix of the checkpoint files. The directory part is searched\n",
    "        for entries whose names start with the final path component.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    bool\n",
    "        False if the directory does not exist or holds no matching file.\n",
    "    \"\"\"\n",
    "    checkpoint_dir, checkpoint_name = os.path.split(checkpoint_base)\n",
    "    # a bare file name has an empty directory part, which means the current\n",
    "    # directory (os.path.exists('') would be False and hide the checkpoint)\n",
    "    checkpoint_dir = checkpoint_dir or os.curdir\n",
    "    if not os.path.isdir(checkpoint_dir):\n",
    "        return False\n",
    "\n",
    "    return any(f.startswith(checkpoint_name)\n",
    "               for f in os.listdir(checkpoint_dir))\n",
    "\n",
    "\n",
    "def get_layer_rates(sim, input_data, rate_probes, amplitude=None):\n",
    "    \"\"\"Collect firing rates on internal layers.\n",
    "\n",
    "    The simulator's local parameters are saved to a temporary file before\n",
    "    the run and loaded back afterwards.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    sim : simulator\n",
    "        Object providing ``save_params``, ``run_steps``, ``load_params``\n",
    "        and ``data``.\n",
    "    input_data : dict\n",
    "        Exactly one entry, mapping the input to a 3-D array whose second\n",
    "        axis gives the number of steps to run.\n",
    "    rate_probes : iterable\n",
    "        Probes (keys of ``sim.data``) to read after the run.\n",
    "    amplitude : float, optional\n",
    "        If given, each probe's data is divided by this value.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    list\n",
    "        One entry per probe, in iteration order of ``rate_probes``.\n",
    "    \"\"\"\n",
    "    assert len(input_data) == 1\n",
    "    in_x = next(iter(input_data.values()))\n",
    "    assert in_x.ndim == 3\n",
    "    n_steps = in_x.shape[1]\n",
    "\n",
    "    # the context manager removes the directory even if the run fails\n",
    "    with tempfile.TemporaryDirectory() as tmpdir:\n",
    "        params_path = os.path.join(tmpdir, \"tmp\")\n",
    "        sim.save_params(params_path,\n",
    "                        include_local=True, include_global=False)\n",
    "\n",
    "        sim.run_steps(n_steps,\n",
    "                      input_feeds=input_data,\n",
    "                      progress_bar=False)\n",
    "\n",
    "        rates = [sim.data[p] for p in rate_probes]\n",
    "        if amplitude is not None:\n",
    "            rates = [rate / amplitude for rate in rates]\n",
    "\n",
    "        sim.load_params(params_path,\n",
    "                        include_local=True, include_global=False)\n",
    "\n",
    "    return rates\n",
    "\n",
    "\n",
    "def get_outputs(sim, input_data, output_probe):\n",
    "    \"\"\"Run the simulator on ``input_data`` and return one probe's data.\n",
    "\n",
    "    The simulator's local parameters are saved to a temporary file before\n",
    "    the run and loaded back afterwards.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    sim : simulator\n",
    "        Object providing ``save_params``, ``run_steps``, ``load_params``\n",
    "        and ``data``.\n",
    "    input_data : dict\n",
    "        Exactly one entry, mapping the input to a 3-D array whose second\n",
    "        axis gives the number of steps to run.\n",
    "    output_probe : probe\n",
    "        Key of ``sim.data`` to read after the run.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    The value of ``sim.data[output_probe]`` after the run.\n",
    "    \"\"\"\n",
    "    assert len(input_data) == 1\n",
    "    in_x = next(iter(input_data.values()))\n",
    "    assert in_x.ndim == 3\n",
    "    n_steps = in_x.shape[1]\n",
    "\n",
    "    # the context manager removes the directory even if the run fails\n",
    "    with tempfile.TemporaryDirectory() as tmpdir:\n",
    "        params_path = os.path.join(tmpdir, \"tmp\")\n",
    "        sim.save_params(params_path,\n",
    "                        include_local=True, include_global=False)\n",
    "\n",
    "        sim.run_steps(n_steps,\n",
    "                      input_feeds=input_data,\n",
    "                      progress_bar=False)\n",
    "\n",
    "        outs = sim.data[output_probe]\n",
    "\n",
    "        sim.load_params(params_path,\n",
    "                        include_local=True, include_global=False)\n",
    "\n",
    "    return outs\n",
    "\n",
    "\n",
    "class ImageShifter(object):\n",
    "    \"\"\"Randomly shift (translate) images, for data augmentation.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    shape : ImageShape\n",
    "        Shape of the images to be shifted.\n",
    "    shift : int or (int, int) (Default: 0)\n",
    "        Maximum shift in vertical and horizontal directions. Output images\n",
    "        are smaller than the inputs by ``2 * shift`` along each direction.\n",
    "    flip : boolean (Default: False)\n",
    "        Whether to randomly flip the images horizontally.\n",
    "    shift_std : float or (float, float) (Default: 1)\n",
    "        Standard deviation of the shift in vertical and horizontal\n",
    "        directions, expressed as a fraction of the maximum shift.\n",
    "    images : ndarray (n_images, d0, d1, d2) (Default: None)\n",
    "        Default images for member functions, if no other images are provided.\n",
    "    rng : np.random.RandomState\n",
    "        Default random number generator for member functions.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, shape, shift=0, flip=False, shift_std=1., images=None,\n",
    "                 rng=np.random):\n",
    "        if images is not None:\n",
    "            assert images.shape[-3:] == shape.shape()\n",
    "\n",
    "        self.images = images\n",
    "        self.shape = shape\n",
    "        self.shift = shift if is_iterable(shift) else (shift, shift)\n",
    "        self.shift_std = shift_std if is_iterable(shift_std) else (\n",
    "            shift_std, shift_std)\n",
    "        self.flip = flip\n",
    "        self.rng = rng\n",
    "\n",
    "        shift_i, shift_j = self.shift\n",
    "        self.output_shape = ImageShape(\n",
    "            self.shape.rows - 2*shift_i, self.shape.cols - 2*shift_j,\n",
    "            self.shape.channels, channels_last=self.shape.channels_last)\n",
    "\n",
    "    @staticmethod\n",
    "    def _shift_probabilities(max_shift, rel_std):\n",
    "        \"\"\"Gaussian-shaped probabilities for offsets -max_shift..max_shift.\"\"\"\n",
    "        if max_shift == 0:\n",
    "            # only one possible offset; the general formula would divide by 0\n",
    "            return np.ones(1)\n",
    "        offsets = np.arange(-max_shift, max_shift + 1)\n",
    "        p = np.exp(-0.5 * (offsets / (rel_std * max_shift))**2)\n",
    "        return p / p.sum()\n",
    "\n",
    "    def augment(self, images=None, rng=None):\n",
    "        \"\"\"Randomly shift (translate) images, for data augmentation.\n",
    "\n",
    "        Parameters\n",
    "        ----------\n",
    "        images : ndarray (n_images, d0, d1, d2) (Default: None)\n",
    "            Images to shift. If ``shape.channels_last``, then ``(d0, d1, d2)``\n",
    "            will be height, width, and channels; otherwise, they will be\n",
    "            channels, height, and width. If ``None``, use ``self.images``.\n",
    "            Any number of leading batch axes (including none) is accepted.\n",
    "        rng : np.random.RandomState\n",
    "            Random number generator for picking shifts for each image. If\n",
    "            ``None``, use ``self.rng``.\n",
    "\n",
    "        Returns\n",
    "        -------\n",
    "        ndarray\n",
    "            Cropped (and possibly flipped) images with the same leading axes\n",
    "            as ``images`` and trailing axes given by ``self.output_shape``.\n",
    "        \"\"\"\n",
    "        images = self.images if images is None else images\n",
    "        rng = self.rng if rng is None else rng\n",
    "        assert images.shape[-3:] == self.shape.shape()\n",
    "        in_shape = images.shape[-3:]\n",
    "        out_shape = self.output_shape.shape()\n",
    "        batch_shape = images.shape[:-3]\n",
    "        # np.prod(()) is the float 1.0, so cast to keep sizes integral\n",
    "        batch_size = int(np.prod(batch_shape))\n",
    "\n",
    "        nyi, nyj = self.output_shape.rows, self.output_shape.cols\n",
    "        si, sj = self.shift\n",
    "        std_i, std_j = self.shift_std\n",
    "        pi = self._shift_probabilities(si, std_i)\n",
    "        pj = self._shift_probabilities(sj, std_j)\n",
    "        # per-image top-left crop corner and flip flag\n",
    "        ii = rng.choice(np.arange(len(pi)), p=pi, size=batch_size)\n",
    "        jj = rng.choice(np.arange(len(pj)), p=pj, size=batch_size)\n",
    "        ff = rng.randint(0, 2 if self.flip else 1, size=batch_size)\n",
    "\n",
    "        out = np.zeros(batch_shape + out_shape, dtype=images.dtype)\n",
    "        images_flat = images.reshape((batch_size,) + in_shape)\n",
    "        # view onto ``out``, so writes below fill the returned array\n",
    "        out_flat = out.reshape((batch_size,) + out_shape)\n",
    "\n",
    "        ch_last = self.shape.channels_last\n",
    "        for k, (i, j, f) in enumerate(zip(ii, jj, ff)):\n",
    "            image = images_flat[k]\n",
    "            image = (image[i:i+nyi, j:j+nyj, :] if ch_last else\n",
    "                     image[:, i:i+nyi, j:j+nyj])\n",
    "            if f:\n",
    "                # reverse the column (width) axis\n",
    "                image = image[:, ::-1, :] if ch_last else image[:, :, ::-1]\n",
    "            out_flat[k] = image\n",
    "\n",
    "        return out\n",
    "\n",
    "    def center(self, images=None):\n",
    "        \"\"\"Take the center of each image (in lieu of shifting).\n",
    "\n",
    "        This is typically used at test time, to preprocess the test images\n",
    "        so that they are the same size as the shifted training images.\n",
    "\n",
    "        Parameters\n",
    "        ----------\n",
    "        images : ndarray (n_images, d0, d1, d2) (Default: None)\n",
    "            Images to crop. If ``shape.channels_last``, then ``(d0, d1, d2)``\n",
    "            will be height, width, and channels; otherwise, they will be\n",
    "            channels, height, and width. If ``None``, use ``self.images``.\n",
    "        \"\"\"\n",
    "        images = self.images if images is None else images\n",
    "        assert images.shape[-3:] == self.shape.shape()\n",
    "        si, sj = self.shift\n",
    "        ni, nj = self.shape.rows, self.shape.cols\n",
    "        return (images[..., si:ni-si, sj:nj-sj, :] if self.shape.channels_last\n",
    "                else images[..., :, si:ni-si, sj:nj-sj])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# --- load dataset as float32 arrays shaped (n_images, 3, 32, 32)\n",
    "(X_train, y_train), (X_test, y_test), label_names = load_cifar10(\n",
    "    label_names=True)\n",
    "X_train, X_test = [\n",
    "    X.reshape(-1, 3, 32, 32).astype('float32') for X in (X_train, X_test)]\n",
    "\n",
    "# basic normalize to [-1, 1] (assuming raw pixel values in [0, 255]);\n",
    "# dividing by 255. instead would normalize to [0, 1]\n",
    "X_train, X_test = [(X - 127.5) / 127.5 for X in (X_train, X_test)]\n",
    "\n",
    "# one-hot encode the class labels\n",
    "n_classes = len(label_names)\n",
    "T_train, T_test = [\n",
    "    one_hot_from_labels(y, n_classes) for y in (y_train, y_test)]\n",
    "\n",
    "channels_last = False\n",
    "X_shape = ImageShape(32, 32, 3, channels_last=channels_last)\n",
    "if channels_last:\n",
    "    # move the channel axis to the end\n",
    "    X_train, X_test = [\n",
    "        np.transpose(X, (0, 2, 3, 1)) for X in (X_train, X_test)]\n",
    "\n",
    "X_min, X_max = X_train.min(), X_train.max()\n",
    "print(\"X range: %0.3f, %0.3f\" % (X_min, X_max))\n",
    "\n",
    "# data params and data augmentation\n",
    "minibatch_size = 256\n",
    "shifter = ImageShifter(\n",
    "    X_shape, shift=3, flip=True, rng=np.random.RandomState(1))\n",
    "input_shape = shifter.output_shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# --- specify network parameters\n",
    "checkpoint_base = './cifar10_convnet_params'\n",
    "\n",
    "max_rate = 100\n",
    "amp = 1. / max_rate\n",
    "rate_reg = 1e-2\n",
    "rate_target = max_rate * amp  # must be in amplitude scaled units\n",
    "\n",
    "# fallback output range, used if the ANN outputs cannot be measured below\n",
    "default_ann_out_min = -10\n",
    "default_ann_out_max = 5\n",
    "neuron_type = SoftLIFRate(amplitude=amp, sigma=0.01)\n",
    "# neuron_type = nengo.RectifiedLinear(amplitude=amp)\n",
    "\n",
    "# one dict per layer: the function layer, an optional neuron type, and\n",
    "# optional ``on_chip`` / ``no_min_rate`` flags (defaults: True / False)\n",
    "layer_dicts = [\n",
    "    dict(layer_func=TfConv2d('layer1', 4, kernel_size=1),\n",
    "         neuron_type=nengo.RectifiedLinear(amplitude=amp),\n",
    "         on_chip=False, no_min_rate=True),\n",
    "    dict(layer_func=TfConv2d('layer2', 64, strides=2),\n",
    "         neuron_type=neuron_type),\n",
    "    dict(layer_func=TfConv2d('layer3', 96),\n",
    "         neuron_type=neuron_type),\n",
    "    dict(layer_func=TfConv2d('layer4', 128, strides=2),\n",
    "         neuron_type=neuron_type),\n",
    "    dict(layer_func=TfConv2d('layer5', 128, kernel_size=1),\n",
    "         neuron_type=neuron_type),\n",
    "    dict(layer_func=TfDense('layer_out', 10)),\n",
    "]\n",
    "\n",
    "# --- build the nengo_dl network\n",
    "objective = {}  # probe -> loss function\n",
    "rate_probes = collections.OrderedDict()  # layer function -> neuron probe\n",
    "with nengo.Network(seed=0) as net:\n",
    "    net.config[nengo.Ensemble].max_rates = nengo.dists.Choice([max_rate])\n",
    "    net.config[nengo.Ensemble].intercepts = nengo.dists.Choice([0])\n",
    "    net.config[nengo.Connection].synapse = None\n",
    "\n",
    "    nengo_dl.configure_settings(trainable=False)\n",
    "\n",
    "    # the input node that will be used to feed in input images\n",
    "    inp = nengo.Node([0] * input_shape.size, label='input')\n",
    "\n",
    "    layers = []\n",
    "    shape_in = input_shape\n",
    "    x = inp\n",
    "    for layer_dict in layer_dicts:\n",
    "        layer_dict = dict(layer_dict)  # so we can pop\n",
    "        layer_func = layer_dict.pop('layer_func', None)\n",
    "        layer_neuron = layer_dict.pop('neuron_type', None)\n",
    "        on_chip = layer_dict.pop('on_chip', True)\n",
    "        no_min_rate = layer_dict.pop('no_min_rate', False)\n",
    "\n",
    "        shape_out = None\n",
    "        fn_layer = None\n",
    "        neuron_layer = None\n",
    "\n",
    "        if layer_func is not None:\n",
    "            assert shape_in\n",
    "            layer_func.shape_in = shape_in\n",
    "            shape_out = layer_func.output_shape(shape_in)\n",
    "\n",
    "            size_in = shape_in.size if shape_in else x.size_out\n",
    "            size_out = shape_out.size if shape_out else size_in\n",
    "            y = nengo_dl.TensorNode(layer_func,\n",
    "                                    size_in=size_in, size_out=size_out,\n",
    "                                    label=layer_func.name)\n",
    "            nengo.Connection(x, y)\n",
    "            x = y\n",
    "            fn_layer = x\n",
    "\n",
    "        if layer_neuron is not None:\n",
    "            y = nengo.Ensemble(x.size_out, 1, neuron_type=layer_neuron).neurons\n",
    "            # remembered for building the spiking network later on\n",
    "            y.image_shape = shape_out if shape_out is not None else shape_in\n",
    "            y.on_chip = on_chip\n",
    "            nengo.Connection(x, y)\n",
    "            x = y\n",
    "            neuron_layer = x\n",
    "\n",
    "            # regularize the layer's 99.9th percentile rate into a range\n",
    "            yp = nengo.Probe(y)\n",
    "            if no_min_rate:\n",
    "                objective[yp] = partial(\n",
    "                    percentile_l2_loss_range, weight=10*rate_reg,\n",
    "                    min=0, max=rate_target, percentile=99.9)\n",
    "            else:\n",
    "                objective[yp] = partial(\n",
    "                    percentile_l2_loss_range, weight=rate_reg,\n",
    "                    min=0.75*rate_target, max=rate_target, percentile=99.9)\n",
    "            rate_probes[layer_func] = yp\n",
    "\n",
    "        shape_in = shape_out\n",
    "        layers.append((fn_layer, neuron_layer))\n",
    "        print(\"Added layer %s %s shape=%s size=%s\" % (\n",
    "            layer_func, layer_neuron, shape_out, shape_out.size))\n",
    "\n",
    "    out_p = nengo.Probe(x)\n",
    "    out_p_filt = nengo.Probe(x, synapse=0.1)\n",
    "\n",
    "objective[out_p] = crossentropy\n",
    "\n",
    "# set up training/test data: center-cropped images, flattened, one time step\n",
    "train_inputs = {inp: shifter.center(X_train).reshape(X_train.shape[0], 1, -1)}\n",
    "train_targets = {out_p: T_train.reshape(T_train.shape[0], 1, -1)}\n",
    "test_inputs = {inp: shifter.center(X_test).reshape(X_test.shape[0], 1, -1)}\n",
    "test_targets = {out_p: T_test.reshape(T_test.shape[0], 1, -1)}\n",
    "rate_inputs = {inp: shifter.center(X_train[:minibatch_size]).reshape(\n",
    "    minibatch_size, 1, -1)}\n",
    "\n",
    "# train our network in NengoDL (or load previously trained parameters)\n",
    "with nengo_dl.Simulator(net, minibatch_size=minibatch_size) as sim:\n",
    "    if has_checkpoint(checkpoint_base):\n",
    "        sim.load_params(checkpoint_base)\n",
    "    else:\n",
    "        print(\"Test error before training: %.2f%%\" %\n",
    "              sim.loss(test_inputs, test_targets, classification_error))\n",
    "\n",
    "        # run training, drawing freshly augmented images every epoch\n",
    "        n_epochs = 30\n",
    "        optimizer = tf.train.RMSPropOptimizer(learning_rate=0.001)\n",
    "        for epoch in range(n_epochs):\n",
    "            train_augmented = {\n",
    "                inp: shifter.augment(X_train).reshape(X_train.shape[0], 1, -1)}\n",
    "            sim.train(train_augmented, train_targets, optimizer,\n",
    "                      objective=objective, n_epochs=1)\n",
    "            print(\"Test error after epoch %d: %.2f%%\" % (\n",
    "                epoch + 1,\n",
    "                sim.loss(test_inputs, test_targets, classification_error)))\n",
    "\n",
    "        sim.save_params(checkpoint_base)\n",
    "\n",
    "    ann_test_preds = None\n",
    "    try:\n",
    "        print(\"Train error after training: %.2f%%\" %\n",
    "              sim.loss(train_inputs, train_targets, classification_error))\n",
    "        print(\"Test error after training: %.2f%%\" %\n",
    "              sim.loss(test_inputs, test_targets, classification_error))\n",
    "\n",
    "        mini_test_inputs = {\n",
    "            inp: shifter.center(X_test[:minibatch_size]).reshape(\n",
    "                minibatch_size, 1, -1)}\n",
    "        ann_test_outs = get_outputs(sim, mini_test_inputs, out_p)\n",
    "        ann_test_preds = np.argmax(ann_test_outs[:, 0, :], axis=-1)\n",
    "\n",
    "        rates = get_layer_rates(sim, rate_inputs, rate_probes.values(),\n",
    "                                amplitude=amp)\n",
    "        for layer_func, rate in zip(rate_probes, rates):\n",
    "            print(\"%s rate: mean=%0.3f, 99th: %0.3f\" % (\n",
    "                layer_func, rate.mean(), np.percentile(rate, 99)))\n",
    "\n",
    "        # compute output range\n",
    "        outs = get_outputs(sim, rate_inputs, out_p)\n",
    "        print(\"Output range: min=%0.3f, 1st=%0.3f, 99th=%0.3f, max=%0.3f\" % (\n",
    "            outs.min(), np.percentile(outs, 1), np.percentile(outs, 99),\n",
    "            outs.max()))\n",
    "        ann_out_min = np.percentile(outs, 1)\n",
    "        ann_out_max = np.percentile(outs, 99)\n",
    "    except Exception as err:\n",
    "        # best effort: keep going with the default range, but say why\n",
    "        print(\"Could not compute ANN values on this machine\")\n",
    "        print(\"  (%s: %s)\" % (type(err).__name__, err))\n",
    "\n",
    "        ann_out_min = default_ann_out_min\n",
    "        ann_out_max = default_ann_out_max\n",
    "\n",
    "    # store trained parameters back into the network\n",
    "    for fn_layer, _ in layers:\n",
    "        fn_layer.tensor_func.save_params(sim.sess)\n",
    "\n",
    "del net, sim, x  # so we don't accidentally use them"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# --- Spiking network\n",
    "# each (center-cropped, flattened) test image is shown for presentation_time\n",
    "presentation_images = shifter.center(X_test).reshape(X_test.shape[0], -1)\n",
    "presentation_time = 0.2\n",
    "present_images = nengo.processes.PresentInput(\n",
    "    presentation_images, presentation_time)\n",
    "\n",
    "# upper bound on the size of each on-chip ensemble\n",
    "max_neurons_per_core = 1024\n",
    "\n",
    "with nengo.Network() as loihi_net:\n",
    "    nengo_loihi.add_params(loihi_net)  # allow setting on_chip\n",
    "\n",
    "    loihi_net.config[nengo.Ensemble].max_rates = nengo.dists.Choice([100])\n",
    "    loihi_net.config[nengo.Ensemble].intercepts = nengo.dists.Choice([0])\n",
    "    loihi_net.config[nengo.Connection].synapse = 0.01  # can also try 0.005\n",
    "\n",
    "    inp = nengo.Node(present_images, label='input')\n",
    "    inp_p = nengo.Probe(inp)\n",
    "    # outputs of the previous layer and the image slice each one covers\n",
    "    xx = [inp]\n",
    "    xslices = [ImageSlice(input_shape)]\n",
    "    cores_used = 0\n",
    "\n",
    "    for fn_layer, neuron_layer in layers:\n",
    "        func = fn_layer.tensor_func\n",
    "\n",
    "        # --- create neuron layer\n",
    "        yy = []\n",
    "        yslices = []\n",
    "        if neuron_layer is not None:\n",
    "            # swap the rate neurons used for training for spiking ones\n",
    "            layer_neurons = neuron_layer.ensemble.neuron_type\n",
    "            if isinstance(layer_neurons, SoftLIFRate):\n",
    "                layer_neurons = nengo.LIF(\n",
    "                    tau_rc=layer_neurons.tau_rc,\n",
    "                    tau_ref=layer_neurons.tau_ref,\n",
    "                    amplitude=layer_neurons.amplitude)\n",
    "            elif isinstance(layer_neurons, nengo.RectifiedLinear):\n",
    "                layer_neurons = nengo.SpikingRectifiedLinear(\n",
    "                    amplitude=layer_neurons.amplitude)\n",
    "            else:\n",
    "                raise ValueError(\"Unsupported neuron type %s\" % layer_neurons)\n",
    "\n",
    "            image_shape = neuron_layer.image_shape\n",
    "            if not neuron_layer.on_chip:\n",
    "                # off-chip layers stay as one ensemble and use no cores\n",
    "                y = nengo.Ensemble(\n",
    "                    image_shape.size, 1,\n",
    "                    neuron_type=layer_neurons,\n",
    "                    label=\"%s\" % (func.name,))\n",
    "                loihi_net.config[y].on_chip = False\n",
    "                yy.append(y.neurons)\n",
    "                yslices.append(ImageSlice(image_shape))\n",
    "            else:\n",
    "                # split by channels into ensembles of bounded size\n",
    "                split_slices = image_shape.split_channels(\n",
    "                    max_size=max_neurons_per_core, max_channels=8)\n",
    "                for image_slice in split_slices:\n",
    "                    assert image_slice.size <= max_neurons_per_core\n",
    "                    idxs = image_slice.channel_idxs()\n",
    "                    y = nengo.Ensemble(\n",
    "                        image_slice.size, 1,\n",
    "                        neuron_type=layer_neurons,\n",
    "                        label=\"%s_%d:%d\" % (func.name, min(idxs), max(idxs)))\n",
    "                    yy.append(y.neurons)\n",
    "                    yslices.append(image_slice)\n",
    "                    cores_used += 1\n",
    "        else:\n",
    "            # output layer: map the ANN output range [ann_out_min,\n",
    "            # ann_out_max] linearly onto input currents [0, max_rate]\n",
    "            output_shape = func.output_shape(func.shape_in)\n",
    "            assert output_shape.size == fn_layer.size_out\n",
    "            max_rate = 100.  # can also try 300.\n",
    "            gain = max_rate / (ann_out_max - ann_out_min)\n",
    "            bias = -gain * ann_out_min\n",
    "            y = nengo.Ensemble(output_shape.size, 1, label=func.name,\n",
    "                               neuron_type=nengo.SpikingRectifiedLinear(),\n",
    "                               gain=nengo.dists.Choice([gain]),\n",
    "                               bias=nengo.dists.Choice([bias]))\n",
    "            yy.append(y.neurons)\n",
    "            yslices.append(ImageSlice(output_shape))\n",
    "            out_p = nengo.Probe(y.neurons, synapse=nengo.Alpha(0.02))\n",
    "            cores_used += 1\n",
    "\n",
    "        assert len(yy) == len(yslices)\n",
    "\n",
    "        # --- create function layer: connect every input piece to every\n",
    "        # output piece with the matching part of the trained transform\n",
    "        transform = func.transform()\n",
    "        for x, xslice in zip(xx, xslices):\n",
    "            for y, yslice in zip(yy, yslices):\n",
    "                transform_xy = split_transform(transform, xslice, yslice)\n",
    "                nengo.Connection(x, y, transform=transform_xy)\n",
    "\n",
    "        xx = yy\n",
    "        xslices = yslices\n",
    "\n",
    "print(\"Used %d cores\" % cores_used)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Run the spiking neural network on Loihi\n",
    "\n",
    "Now, we run the spiking model on Loihi (or in the emulator if ``nxsdk`` is \n",
    "not installed). For demonstration purposes, we only run 10 examples, but \n",
    "feel free to run 100 or more examples if you wish to get a better idea of \n",
    "the network accuracy."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "n_presentations = 10\n",
    "\n",
    "with nengo_loihi.Simulator(loihi_net, precompute=False) as sim:\n",
    "    if 'loihi' in sim.sims:\n",
    "        sim.sims['loihi'].snip_max_spikes_per_step = 1000\n",
    "\n",
    "    sim.run(n_presentations * presentation_time)\n",
    "\n",
    "class_output = sim.data[out_p]\n",
    "# round instead of truncating: a ratio such as 0.3 / 0.001 evaluates to\n",
    "# 299.99999999999994, which int() alone would turn into 299 steps\n",
    "steps_per_pres = int(round(presentation_time / sim.dt))\n",
    "preds = []\n",
    "for i in range(0, class_output.shape[0], steps_per_pres):\n",
    "    c = class_output[i:i + steps_per_pres]\n",
    "    c = c[int(0.7 * steps_per_pres):]  # take last part\n",
    "    pred = np.argmax(c.sum(axis=0))\n",
    "    preds.append(pred)\n",
    "\n",
    "print(\"Predictions: %s\" % (preds,))\n",
    "if ann_test_preds is not None:\n",
    "    print(\"ANN:         %s\" % (list(ann_test_preds[:n_presentations]),))\n",
    "print(\"Actual:      %s\" % (list(y_test[:n_presentations]),))\n",
    "\n",
    "# compare as arrays so the element-wise comparison is explicit\n",
    "preds_array = np.asarray(preds)\n",
    "error = (preds_array != y_test[:n_presentations]).mean()\n",
    "print(\"Accuracy: %0.3f%%, Error: %0.3f%%\" % (100 - 100*error, 100*error))\n",
    "if ann_test_preds is not None:\n",
    "    ann_match = (preds_array == ann_test_preds[:n_presentations]).mean()\n",
    "    print(\"ANN match: %0.3f%%\" % (100*ann_match,))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# --- fancy plots\n",
    "plt.figure(figsize=(12, 6))\n",
    "\n",
    "# top panel: the presented test images, side by side\n",
    "plt.subplot(2, 1, 1)\n",
    "images = (X_test if X_shape.channels_last\n",
    "          else np.transpose(X_test, (0, 2, 3, 1)))\n",
    "ni, nj, nc = images[0].shape\n",
    "allimage = np.zeros((ni, nj*n_presentations, nc))\n",
    "for i, image in enumerate(images[:n_presentations]):\n",
    "    allimage[:, i*nj:(i + 1)*nj] = image\n",
    "if allimage.shape[-1] == 1:\n",
    "    allimage = allimage[:, :, 0]\n",
    "# rescale from the training data range to [0, 1] for display\n",
    "allimage = (allimage - X_min) / (X_max - X_min)\n",
    "plt.imshow(allimage, aspect='auto', interpolation='none', cmap='gray')\n",
    "\n",
    "# bottom panel: output probe data for each class over time\n",
    "plt.subplot(2, 1, 2)\n",
    "t = sim.trange()\n",
    "plt.plot(t, sim.data[out_p])\n",
    "plt.xlim([t[0], t[-1]])\n",
    "plt.xlabel(\"time [s]\")\n",
    "plt.ylabel(\"probed output\")\n",
    "# label names may be bytes or str depending on how they were loaded\n",
    "plt.legend([label.decode() if isinstance(label, bytes) else label\n",
    "            for label in label_names],\n",
    "           loc='upper right',\n",
    "           bbox_to_anchor=(1.18, 1.05));"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
